import os
import random
import re

def read_qa_pairs(filepath):
    """Load QA pairs from a UTF-8 text file.

    The file content is stripped and then split on two consecutive
    newlines. Each resulting chunk is stripped of surrounding whitespace,
    and chunks that end up empty are dropped.

    Args:
        filepath: Path of the text file to read.

    Returns:
        A list of non-empty, stripped chunks in the order they appear
        in the file.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        text = handle.read().strip()

    qa_pairs = []
    for chunk in text.split('\n\n'):
        cleaned = chunk.strip()
        if cleaned:
            qa_pairs.append(cleaned)
    return qa_pairs

def restore_answer(qa_pair):
    """Rewrite the final answer line of a QA pair that carries a <suspect> marker.

    The pair is returned unchanged unless it contains the text "<suspect>"
    and a sentence of the form ``Concatenating them is "<value>".``. When
    both are present and the last line starts with "The answer is", that
    last line is replaced by ``The answer is "<value>".`` using the first
    such quoted value found in the pair.

    Args:
        qa_pair: The full text of one QA pair.

    Returns:
        The QA pair text, with its last line rewritten when the conditions
        above hold; otherwise text equal to the input.
    """
    has_marker = "<suspect>" in qa_pair
    found = (
        re.search(r'Concatenating them is "([^"]+)"\.', qa_pair)
        if has_marker
        else None
    )
    if found is None:
        return qa_pair

    # Only the very last line is a candidate for replacement.
    *head, tail = qa_pair.split('\n')
    if tail.startswith("The answer is"):
        tail = f'The answer is "{found.group(1)}".'

    return '\n'.join([*head, tail])

def mix_samples(sample_size=None, base_path="/"):
    """Sample, restore, shuffle and write a mixed set of letter QA pairs.

    Reads three labeled files (``letter_label_bef.txt``,
    ``letter_label_mid.txt``, ``letter_label_last.txt``) from
    ``<base_path>/labeled_backdoor/letter`` and one clean file
    ``<base_path>/clean_data/reasoning_output_letter.txt``. The same number
    of pairs is drawn at random from each of the four pools, the labeled
    samples are passed through ``restore_answer``, everything is shuffled
    together and written, separated by blank lines, to
    ``<base_path>/grpo_meterial/anti_mixed_letter_data_100*4.txt``.

    Args:
        sample_size: Number of pairs to draw from each pool. ``None`` means
            "as many as the smallest pool holds". Larger requests are
            capped at the size of the smallest pool.
        base_path: Root directory that contains the input folders and
            receives the output folder. Defaults to ``"/"``.

    Raises:
        ValueError: If ``sample_size`` is negative.
        OSError: If an input file cannot be read or the output cannot be
            written.
    """
    backdoor_path = os.path.join(base_path, "labeled_backdoor/letter")
    labeled_pools = [
        read_qa_pairs(os.path.join(backdoor_path, file_name))
        for file_name in (
            "letter_label_bef.txt",
            "letter_label_mid.txt",
            "letter_label_last.txt",
        )
    ]
    clean_pairs = read_qa_pairs(
        os.path.join(base_path, "clean_data/reasoning_output_letter.txt")
    )

    # Every pool contributes the same count, so the smallest pool is the cap.
    max_possible = min(len(pool) for pool in (*labeled_pools, clean_pairs))
    if sample_size is None:
        sample_size = max_possible
    elif sample_size < 0:
        # Fail early with a clear message instead of deep inside random.sample.
        raise ValueError(
            f"sample_size must be non-negative, got {sample_size}"
        )
    else:
        sample_size = min(sample_size, max_possible)

    # Draw in a fixed order (bef, mid, last, clean) so that a seeded RNG
    # produces reproducible output.
    all_samples = []
    for pool in labeled_pools:
        all_samples.extend(
            restore_answer(qa) for qa in random.sample(pool, sample_size)
        )
    all_samples.extend(random.sample(clean_pairs, sample_size))

    random.shuffle(all_samples)

    # NOTE(review): the file name hard-codes "100*4" regardless of the actual
    # sample_size, and "*" is not a valid file-name character on Windows.
    # Kept as-is because other tooling may rely on this exact path.
    output_path = os.path.join(base_path, "grpo_meterial/anti_mixed_letter_data_100*4.txt")
    # Create the output folder if needed so the run does not fail at the
    # very last step.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n\n'.join(all_samples))

    print(f"Mixing completed! Sample size per category: {sample_size}")
    print("Restored answers for QA pairs with <suspect> markers to original answers")

# Script entry point: request 100 QA pairs per category (mix_samples caps
# the request at the size of the smallest input pool).
if __name__ == "__main__":
    mix_samples(100)